Adds AdeMAMix Optimizer to contrib #1104

Open · wants to merge 33 commits into main

Conversation

@mathDR commented Oct 14, 2024

This PR adds the AdEMAMix optimizer from the arXiv preprint "The AdEMAMix Optimizer: Better, Faster, Older".

Closes #1058

The docs have been updated, along with the relevant files.

Currently the docstrings are implemented, but further descriptions should/could be added. I will reach out to the paper authors to assist with that (if they are willing).

docs/api/contrib.rst (outdated, resolved)
optax/contrib/_ademamix.py (outdated, resolved)
class ScaleByAdemamixState(NamedTuple):
  """State for the Ademamix algorithm."""

  count: chex.Array


A comment here about the shape of count is useful (since it's very specific) and commonplace in the code: https://github.com/google-deepmind/optax/blob/main/optax/_src/transform.py#L88

@mathDR (author):

Oh nice. I see it is

count: chex.Array  # shape=(), dtype=jnp.int32.

everywhere.

    alpha_scheduler: Optional[base.ScalarOrSchedule] = None,
    eps: float = 1e-8,
) -> base.GradientTransformation:
  """Rescale updates according to the Ademamix algorithm.


A brief description of AdEMAMix and how it actually operates may be useful. For example, a comparison to Adam (on which it is based).


Alternatively, pseudo-code resembling https://github.com/google-deepmind/optax/blob/main/optax/_src/alias.py#L586-L594 may be helpful.
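For example, a sketch of such pseudo-code for the AdEMAMix update could read roughly as follows (based on the update rule described in the paper, with weight decay and the b3/alpha schedules omitted; illustrative only, not the final docstring wording):

  # Sketch of the AdEMAMix update rule (weight decay and schedules omitted).
  m1 = b1 * m1 + (1 - b1) * g         # fast EMA of gradients, as in Adam
  m2 = b3 * m2 + (1 - b3) * g         # slow EMA of gradients
  nu = b2 * nu + (1 - b2) * g**2      # second-moment estimate
  m1_hat = m1 / (1 - b1**t)           # bias correction (the slow EMA is left uncorrected)
  nu_hat = nu / (1 - b2**t)
  update = -lr * (m1_hat + alpha * m2) / (sqrt(nu_hat) + eps)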


Args:
  b1: Exponential decay rate to track the first moment of past gradients for
    the first Exponential Moving Average (EMA) - same as AdamW


This is essentially duplicate information relative to "Exponential decay rate". I'd suggest mentioning EMA in a general description above, and then making the arg information resemble that of AdamW.

b3: float = 0.9999,
alpha: float = 5.0,
b3_scheduler: Optional[base.ScalarOrSchedule] = None,
alpha_scheduler: Optional[base.ScalarOrSchedule] = None,
@zcharles8 commented Oct 17, 2024:

Why not make alpha (and b3) a base.ScalarOrSchedule instead?

@mathDR (author):

I noticed the pattern in the majority of the other classes, which do not use base.ScalarOrSchedule as the type, and tried to replicate that.

Of course, the authors in the paper specifically recommend scheduling b3 and alpha because of early training instabilities, so this is a bit of a different use case.

I will make this change.
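For what it's worth, one minimal way to accept either a constant or a schedule could look roughly like this (a sketch only; the helper name below is made up for illustration):

  import chex
  from optax._src import base

  def _value_at_count(x: base.ScalarOrSchedule, count: chex.Numeric):
    # A ScalarOrSchedule is either a constant or a callable schedule of the step count.
    return x(count) if callable(x) else x

  # Then, inside the update function, something like:
  # c_b3 = _value_at_count(b3, state.count)
  # c_alpha = _value_at_count(alpha, state.count)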

"""State for the Ademamix algorithm."""

count: chex.Array
count_m2: chex.Array


I'm not sure why you have multiple counts. Would just the first suffice?

@mathDR (author):

So the first count replicates the same entity in Adam. The second is for keeping track of the second momentum term (the slow EMA), so the user might want to do different things with each.

Does that make sense?

I will add an inline comment to this effect.


I'm still not convinced that there's utility in having separate counts - they're incremented identically in the actual code. I would lean towards having a single count, but I don't think this is blocking.


Super cool work - reproducing results is definitely under-appreciated. This is really nice!

m1 = otu.tree_update_moment(
    updates, state.m1, b1, 1
)  # m1 = b1 * m1 + (1-b1) * updates
m2 = otu.tree_update_moment(updates, state.m2, c_b3, 1)


I know this isn't done elsewhere in optax, but generally I prefer opaque constants (e.g. 1 in this case) to be passed in via kwargs. Consider this optional though.

@mathDR (author):

so:

m2 = otu.tree_update_moment(updates, state.m2, c_b3, order=1)

is what you are describing?


Exactly.

optax/contrib/_ademamix.py (resolved)
@mathDR commented Oct 18, 2024

Hey @zcharles8, in rereading the paper I discovered that the authors have a full JAX implementation here.

I will email them and see if they have any qualms about me using aspects of their code as part of this PR.

Do you know if there would be licensing problems with this approach?

Please advise.

@vroulet (Collaborator) commented Oct 18, 2024

Do you know if there would be licensing problems with this approach?

Great question; from reading Apple's license, there probably would be... Let me send a mail to one of the authors I know.

@mathDR commented Oct 18, 2024

Thanks @vroulet. I also tried sending an email to all 3 authors (but I had to google their contact info, so those addresses might be outdated).

Basically I just asked whether they are keen to get a version into contrib; if so, I would continue my PR but use their docstrings (with attribution).

But if they don't want a version in optax and want users to use their version, I would close the PR.

@mathDR requested a review from zcharles8, October 22, 2024 01:43
@mathDR commented Oct 22, 2024

Emailed with Matteo Pagliardini and he wrote:

Very happy to hear that you found the work useful and thanks for the PR. Having AdEMAMix as part of Optax would be great. Feel free to proceed with the PR and use our docstrings. 

So I went ahead and completed the PR.

One open question: AdEMAMix uses a pretty bespoke scheduler for b3. The alpha scheduler can be implemented via the vanilla linear_schedule, but I couldn't find a drop-in replacement for the b3 scheduler.

Both are now used in the Rosenbrock example, but I didn't know whether we want to add a new scheduler type for completeness?
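For concreteness, here is roughly what I mean (a sketch only: alpha reuses the stock linear_schedule, while b3 gets a custom callable that linearly interpolates the EMA time scale 1/ln(b3), which is my reading of the paper's scheduler and should be checked against the reference implementation; the step counts are placeholder values):

  import jax.numpy as jnp
  import optax

  # Alpha warmup via the stock linear schedule (placeholder values).
  alpha_scheduler = optax.linear_schedule(
      init_value=0.0, end_value=5.0, transition_steps=10_000)

  # Custom b3 schedule: interpolate 1/ln(b) linearly from b_start to b_end.
  def b3_scheduler(step, b_start=0.9, b_end=0.9999, transition_steps=10_000):
    frac = jnp.clip(step / transition_steps, 0.0, 1.0)
    log_b = (jnp.log(b_start) * jnp.log(b_end)) / (
        (1.0 - frac) * jnp.log(b_end) + frac * jnp.log(b_start))
    return jnp.minimum(jnp.exp(log_b), b_end)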

import chex
import jax.numpy as jnp
import jax.tree_util as jtu
from optax._src import base
from optax._src import combine
from optax._src import numerics
from optax._src import transform

import optax.tree_utils as otu
from typing import NamedTuple, Tuple


I think the convention in this library is that imports from python generally go at the top, so

from typing import...

import chex
...

@mathDR (author):

Ah! Good point. (I was going to make a meta point about this repo not having any pre-commit hooks but...).
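For reference, the reordered block would then read (same imports as in the diff above, just regrouped):

  from typing import NamedTuple, Tuple

  import chex
  import jax.numpy as jnp
  import jax.tree_util as jtu
  from optax._src import base
  from optax._src import combine
  from optax._src import numerics
  from optax._src import transform
  import optax.tree_utils as otu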

@zcharles8:

One open question: AdEMAMix uses a pretty bespoke scheduler for b3. The alpha scheduler can be implemented via the vanilla linear_schedule, but I couldn't find a drop-in replacement for the b3 scheduler.

Both are now used in the Rosenbrock example, but I didn't know whether we want to add a new scheduler type for completeness?

I think that adding the schedulers to the library (alpha_scheduler and b3_scheduler) is totally reasonable, if for no other reason than to make it easy for someone to directly use the schedulers from the paper.

FWIW I think that this PR is really good (e.g. things like the typing information) and would prefer to have it pushed through (nice work @mathDR!)

@vroulet (Collaborator) left a comment:

Looks great. Thank you @mathDR ! And thank you @zcharles8 for the review!

I've left a few minor comments. Mostly, it would be good to clean up the docstring of scale_by_ademamix now that you have written the docstring of ademamix. And you'll be able to resolve any remaining comments that way.

Thank you again to both of you!

optax/contrib/_ademamix.py (resolved)
    eps: float = 1e-8,
    eps_root: float = 0.0,
) -> base.GradientTransformation:
  """Rescale updates according to the Ademamix algorithm.
Collaborator:

Could you update this docstring too? The ademamix docstring looks great. You may simply refer here to the docstring of ademamix for e.g. the pseudocode and copy-paste e.g. the descriptions of the arguments.

@mathDR (author):

So to clarify: just a link to the ademamix docstring? That is, don't follow the convention of:

  Description

  References:

  Args:

  Returns:
    A `GradientTransformation` object.

?

Note: the above convention is prevalent throughout optax/_src/transform.py for similar scale_by_X functions.

Collaborator:

You can use the following docstring

  """Scale updates according to the Ademamix algorithm.

  See :func:`optax.contrib.ademamix` for a full description of the algorithm.

  References:
    Pagliardini et al, `The AdEMAMix Optimizer: Better, Faster, Older
    <https://arxiv.org/abs/2409.03137>`_, 2024

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    b1: Exponential decay rate to track the fast EMA.
    b2: Exponential decay rate to track the second moment of past gradients.
    b3: Exponential decay rate to track the slow EMA.
    alpha: Mixing coefficient in the linear combination of the fast and
      slow EMAs.
    eps: A small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: A small constant applied to denominator inside the square root (as
      in RMSProp), to avoid dividing by zero when rescaling. This is needed for
      instance when computing (meta-)gradients through Adam.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    The corresponding `GradientTransformation`.
  """

import optax.tree_utils as otu

class ScaleByAdemamixState(NamedTuple):
  """State for the Ademamix algorithm."""
Collaborator:

Could you add a description of the attributes in the docstring?
See e.g.

class ScaleByLBFGSState(NamedTuple):

@mathDR (author):

I took a stab. LMK if there are better descriptions I could do. I (theoretically) could add the math update if you believe that would be warranted. I also want to get the description right. nu is still the estimate of the second moment, but now the first moment is a combination of m1 and m2 so I just stated their respective EMA types.
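For reference, the attribute descriptions could look roughly like this (a sketch based on the fields discussed in this thread, namely count, count_m2, m1, m2 and nu; not necessarily the exact wording that ends up in the PR):

  class ScaleByAdemamixState(NamedTuple):
    """State for the AdEMAMix algorithm.

    Attributes:
      count: iteration count for the fast EMA and the second moment,
        shape=(), dtype=jnp.int32.
      count_m2: iteration count for the slow EMA, shape=(), dtype=jnp.int32.
      m1: fast EMA of the first moment of past gradients.
      m2: slow EMA of the first moment of past gradients.
      nu: estimate of the second moment of past gradients.
    """

    count: chex.Array  # shape=(), dtype=jnp.int32.
    count_m2: chex.Array  # shape=(), dtype=jnp.int32.
    m1: base.Updates
    m2: base.Updates
    nu: base.Updates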

Collaborator:

Looks great to me.
I think I would say "fast EMA of first moment" rather than first moment personally (same for "slow EMA" and this would apply in the other docstrings). But if you have stronger preferences I won't argue against.

@vroulet (Collaborator) left a comment:

OK, final comments and then we'll be good to go I think

S_t &\leftarrow (m1_t, m2_t, v_t).
\end{align*}

Limitations: AdEMAMix consists in leveraging very old gradients. Therefore,
Collaborator:

You can use the following (the current formatting won't be handled well by the docs).

.. note::

    AdEMAMix consists in leveraging very old gradients. Therefore,
    the method is best suited to settings where the number of iterations is
    important. The paper reports on this effect in Appendix C.1.5, showing how
    smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations
    scenarios. Moreover, retaining gradient information over many thousands of
    steps can pose a problem in domains requiring fast adaptation to a sudden
    distribution shift, or general cases in which the distribution is
    non-stationary.

Also, you may build the docs using (from the folder docs)

make html-noplot

that way you can check whether the docs are well formatted.


iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
b1: Exponential decay rate to track the fast EMA.
b2: Exponential decay rate to track the second moment of past gradients.
b3: Exponenital decay rate to track the slow EMA.
Collaborator:

Typo: Exponential


@mathDR commented Oct 24, 2024

Okay thanks for the tip to render the docs. That allowed me to fix a lot of weirdness (and a LaTeX error!).

I think this is good to go!

@zcharles8 left a comment:

LGTM! I don't think I have approval permissions but just want to say nice work.

@vroulet (Collaborator) left a comment:

Thank you again @mathDR !

@mathDR commented Oct 29, 2024

So @vroulet I think you have to merge it? Or does another authorized user have to do it?

@vroulet (Collaborator) commented Oct 29, 2024

The PR needs to be approved by one of the internal owners of the package (as I did); then copybara automatically syncs the PR with the internal code and produces a snapshot for another maintainer to check, and once that other maintainer gives their approval the PR is merged. (That's why PRs often take a bit of time to be merged even after they get approved.)

@fabianp (Member) commented Oct 30, 2024

Thanks for the contribution! Note that you'll need to edit gallery.rst for your example to display in https://optax.readthedocs.io/en/latest/gallery.html

@mathDR commented Oct 30, 2024

Okay great. I updated gallery.rst and added a PNG thumbnail to the images/ directory. I also added a Colab link, as I saw the other examples do the same.

Note: when I render the docs locally, my example shows a KeyboardInterrupt that doesn't exist in the original Jupyter notebook.

Is this an artifact of Sphinx attempting to render the notebook? Hopefully this is a problem with my local env and will not persist in main.

@fabianp (Member) commented Oct 30, 2024

That's because your test is taking too long to run. You can either make it quicker, or add it to nb_execution_excludepatterns in https://github.com/google-deepmind/optax/blob/main/docs/conf.py (but then it won't be run as part of the test suite).
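For example, the exclusion is a one-line addition in docs/conf.py, roughly like this (the notebook filename below is a placeholder for whatever the example file is actually called):

  # docs/conf.py
  nb_execution_excludepatterns = [
      # ... existing entries ...
      'ademamix_example.ipynb',  # placeholder name for the new example notebook
  ]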

Successfully merging this pull request may close these issues: Feature request for the AdeMAMix optimizer.

4 participants